import os
import csv
from p_tqdm import p_map
from argparse import ArgumentParser
from copy import deepcopy

import numpy as np

from learner import SumUCB, GaussTS
from learner import UpUCB, UpUCB_L
from bandit import BernoulliUpliftBandit, interact


def main(params):
    """Run one experiment configuration and save its results to disk.

    Loads ``baseline_means.npy``, ``treated_means.npy`` and
    ``cluster_sizes.npy`` from ``params.data_dir`` and builds a
    ``BernoulliUpliftBandit``. In it, arm ``k`` uses the baseline means
    everywhere except on the ``k``-th contiguous block of variables, where it
    uses the treated means. It then instantiates the learner named by
    ``params.algo`` and runs ``params.n_runs`` seeds in parallel.

    Files written under ``params.save_dir`` (``<prefix>`` encodes algo,
    baseline usage, radius, runs and rounds):

    * ``<prefix>-params.csv``: one ``key,value`` row per parameter.
    * ``<prefix>-regrets.npy``: cumulative sum of each run's regret trace
      (one row per run), keeping every ``params.save_every``-th column.
    * ``<prefix>-arm_his.npy``: each run's arm history (one row per run),
      subsampled the same way.

    Args:
        params: namespace carrying ``data_dir``, ``save_dir``, ``save_every``,
            ``use_baseline``, ``algo``, ``radius``, ``n_runs``, ``n_rounds``,
            ``print_step`` and ``random_seed``.

    Raises:
        ValueError: if ``params.algo`` is not one of ``'UCB'``, ``'UpUCB'``,
            ``'UpUCB_L'``, ``'TS'``, or if ``params.n_runs < 1``.
    """
    n_arms = 20
    n_variables = 100000

    # Validate cheap preconditions before loading data or writing any file.
    # Otherwise a typo in --algo leaves a stray params CSV behind and only
    # surfaces later as an UnboundLocalError on ``learner``.
    known_algos = ('UCB', 'UpUCB', 'UpUCB_L', 'TS')
    if params.algo not in known_algos:
        raise ValueError(
            f"unknown algo {params.algo!r}; expected one of {known_algos}")
    if params.n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {params.n_runs}")

    baseline_means = np.load(os.path.join(params.data_dir, 'baseline_means.npy'))
    treated_means = np.load(os.path.join(params.data_dir, 'treated_means.npy'))
    cluster_sizes = np.load(os.path.join(params.data_dir, 'cluster_sizes.npy'))

    # Arm k affects the k-th contiguous block of variables, i.e. the indices
    # [cluster_ends[k-1], cluster_ends[k]), where the first block starts at 0.
    cluster_ends = np.cumsum(cluster_sizes)
    cluster_starts = np.concatenate(([0], cluster_ends[:-1]))
    affected_sets = [np.arange(cluster_starts[k], cluster_ends[k])
                     for k in range(n_arms)]

    bandit = BernoulliUpliftBandit(n_arms, n_variables, cluster_sizes)
    bandit.affected_sets = affected_sets
    bandit.baseline = baseline_means

    # Every arm starts from the baseline means. Only the variables in its own
    # affected set take the treated means.
    for arm, affected in enumerate(affected_sets):
        bandit.means[arm] = baseline_means
        bandit.means[arm, affected] = treated_means[affected]

    bandit.compute_statistics()

    baseline = bandit.baseline if params.use_baseline else None
    baseline_tag = 'with_baseline' if params.use_baseline else 'without_baseline'

    # Only the radius value is made dot-free. The directory part of the path
    # must stay untouched, so that e.g. --save_dir ./results/ keeps working.
    radius_tag = str(params.radius).replace('.', '_')
    filename = os.path.join(
        params.save_dir,
        f'{params.algo}-{baseline_tag}-radius_{radius_tag}'
        f'-runs_{params.n_runs}-rounds_{params.n_rounds}')

    # The csv module expects files it writes to be opened with newline=''.
    with open(filename + '-params.csv', 'w', newline='') as csvfile:
        csv.writer(csvfile).writerows(vars(params).items())

    # The UCB-style learners receive the radius scaled by 1/n_variables, while
    # GaussTS receives it unscaled.
    # NOTE(review): confirm the unscaled TS radius is intended.
    scale = 1/n_variables

    if params.algo == 'UCB':
        learner = SumUCB(n_arms, n_variables, radius=scale*params.radius)
    elif params.algo == 'UpUCB':
        learner = UpUCB(n_arms,
                        n_variables,
                        affected_sets,
                        baseline,
                        radius=scale*params.radius)
    elif params.algo == 'UpUCB_L':
        learner = UpUCB_L(n_arms,
                          n_variables,
                          max(cluster_sizes),
                          baseline,
                          radius=scale*params.radius)
    else:  # 'TS'; params.algo was validated at the top of this function.
        prior_mean = np.mean(bandit.rewards)
        prior_var = np.var(bandit.rewards)
        learner = GaussTS(n_arms, params.radius, prior_mean, prior_var)

    # One job per seed. Every job gets the same bandit/learner, which ``run``
    # deep-copies before use.
    seeds = range(params.random_seed, params.random_seed + params.n_runs)
    args = [(bandit, learner, params.n_rounds, params.print_step, seed)
            for seed in seeds]
    results = p_map(run_para, args)

    # Each result is a (regrets, arm history) pair. Runs are stacked as rows,
    # regret is accumulated along the round axis, then columns are subsampled.
    step = params.save_every
    regrets = np.cumsum(np.vstack([res[0] for res in results]), axis=1)[:, ::step]
    arm_his = np.vstack([res[1] for res in results])[:, ::step]

    np.save(filename + '-regrets.npy', regrets)
    np.save(filename + '-arm_his.npy', arm_his)


def run(bandit, learner, n_rounds, print_step, rng):
    """Play a single independent run on private copies of the inputs.

    ``bandit`` and ``learner`` are deep-copied before being touched, so the
    caller's objects are left unmodified. The same ``rng`` value is passed to
    the copied bandit's ``init_feedback`` and to the copied learner's
    ``set_rng``. In this script it is the integer run seed chosen by ``main``.

    Returns:
        The two values produced by ``interact``: ``(regrets, arm_his)``.
    """
    # Environment copy, seeded for this run.
    env = deepcopy(bandit)
    env.init_feedback(rng)

    # Agent copy, seeded with the same value.
    agent = deepcopy(learner)
    agent.set_rng(rng)

    regret_trace, arm_trace = interact(env, agent, n_rounds, print_step)
    return regret_trace, arm_trace


def run_para(args):
    """Expand one packed argument tuple into a call to :func:`run`.

    ``p_map`` hands each worker a single item, so the per-run arguments are
    bundled into one tuple and unpacked here.

    Returns:
        The ``(regrets, arm_his)`` pair returned by ``run``.
    """
    outcome = run(*args)
    regret_trace, arm_trace = outcome
    return regret_trace, arm_trace


if __name__ == '__main__':

    parser = ArgumentParser()

    # Input data and output locations.
    parser.add_argument('--data_dir', type=str, default='save/bernoulli_criteo/data/')
    parser.add_argument('--save_dir', type=str, default='save/bernoulli_criteo/results/')
    # Keep every N-th round in the saved regret / arm-history arrays.
    parser.add_argument('--save_every', type=int, default=1)

    # Baseline means are handed to the learner by default;
    # --no_baseline passes None instead.
    parser.add_argument('--no_baseline', dest='use_baseline', action='store_false')
    parser.set_defaults(use_baseline=True)

    # Restricting --algo to the learners main() knows how to build rejects
    # typos at parse time instead of after the data has been loaded.
    parser.add_argument('--algo', type=str, default='UpUCB',
                        choices=['UCB', 'UpUCB', 'UpUCB_L', 'TS'])
    parser.add_argument('--radius', type=float, default=3)

    parser.add_argument('--n_runs', type=int, default=3)
    parser.add_argument('--n_rounds', type=int, default=100)

    parser.add_argument('--print_step', type=int, default=100)

    # Runs use seeds random_seed, random_seed + 1, ..., random_seed + n_runs - 1.
    parser.add_argument('--random_seed', type=int, default=3)

    params = parser.parse_args()

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(params.save_dir, exist_ok=True)

    main(params)
